package com.yling.cnn.articleclassification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.documentiterator.LabelAwareDocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

import lombok.NonNull;

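/**
 * A {@link DataSetIterator} for sentence classification with a CNN: each labelled
 * sentence is tokenized and converted into a 2D "image" of stacked word vectors,
 * giving features of shape [minibatchSize, 1, maxSentenceLength, wordVectorSize]
 * (or with the last two dimensions swapped; see
 * {@link Builder#sentencesAlongHeight(boolean)}). Labels are one-hot encoded using
 * the sorted label list from the {@link LabeledSentenceProvider}.
 *
 * <p>
 * A minimal usage sketch; the model file name and the sentence/label lists are
 * illustrative assumptions, not part of this class:
 *
 * <pre>
 * {@code
 * WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File("wordVectors.bin"));
 * LabeledSentenceProvider provider = new CollectionLabeledSentenceProvider(sentences, labels);
 * DataSetIterator trainIter = new CnnSentenceDataSetIterator.Builder()
 * 		.sentenceProvider(provider)
 * 		.wordVectors(wordVectors)
 * 		.minibatchSize(32)
 * 		.maxSentenceLength(256)
 * 		.build();
 * }
 * </pre>
 */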
public class CnnSentenceDataSetIterator implements DataSetIterator {

	public enum UnknownWordHandling {
		/** Skip words that have no vector in the provided WordVectors instance */
		RemoveWord,
		/** Substitute the WordVectors UNK vector for out-of-vocabulary words */
		UseUnknownVector
	}
	private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";
	private LabeledSentenceProvider sentenceProvider = null;
	private WordVectors wordVectors;
	private TokenizerFactory tokenizerFactory;
	private UnknownWordHandling unknownWordHandling;
	private boolean useNormalizedWordVectors;
	private int minibatchSize;
	private int maxSentenceLength;
	private boolean sentencesAlongHeight;
	private DataSetPreProcessor dataSetPreProcessor;
	private int wordVectorSize;
	private int numClasses;
	private Map<String, Integer> labelClassMap;
	private INDArray unknown;
	private int cursor = 0;

	private CnnSentenceDataSetIterator(Builder builder) {
		this.sentenceProvider = builder.sentenceProvider;
		this.wordVectors = builder.wordVectors;
		this.tokenizerFactory = builder.tokenizerFactory;
		this.unknownWordHandling = builder.unknownWordHandling;
		this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
		this.minibatchSize = builder.minibatchSize;
		// Treat a non-positive max length as "no limit"; otherwise the truncation checks below would break
		this.maxSentenceLength = (builder.maxSentenceLength > 0 ? builder.maxSentenceLength : Integer.MAX_VALUE);
		this.sentencesAlongHeight = builder.sentencesAlongHeight;
		this.dataSetPreProcessor = builder.dataSetPreProcessor;
		this.numClasses = this.sentenceProvider.numLabelClasses();
		this.labelClassMap = new HashMap<>();
		int count = 0;
		// Sort the labels first to ensure a consistent label-to-index assignment
		// (for example, between train and test iterators)
		List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
		Collections.sort(sortedLabels);
		for (String s : sortedLabels) {
			this.labelClassMap.put(s, count++);
		}
		if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
			// Cache the UNK vector so getVector() can substitute it for unknown words
			if (useNormalizedWordVectors) {
				unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
			} else {
				unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
			}
		}
		// Infer the word vector dimensionality from the first word in the vocabulary
		this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
	}

	/**
	 * Converts a single sentence to a features array; generally used after training
	 * to obtain input for prediction.
	 */
	public INDArray loadSingleSentence(String sentence) {
		List<String> tokens = tokenizeSentence(sentence);
		if (tokens.isEmpty()) {
			throw new IllegalStateException(
					"Cannot convert sentence: no tokens remain after tokenization and unknown-word handling");
		}
		int[] featuresShape = new int[] { 1, 1, 0, 0 };
		if (sentencesAlongHeight) {
			featuresShape[2] = Math.min(maxSentenceLength, tokens.size());
			featuresShape[3] = wordVectorSize;
		} else {
			featuresShape[2] = wordVectorSize;
			featuresShape[3] = Math.min(maxSentenceLength, tokens.size());
		}
		INDArray features = Nd4j.create(featuresShape);
		int length = (sentencesAlongHeight ? featuresShape[2] : featuresShape[3]);
		for (int i = 0; i < length; i++) {
			INDArray vector = getVector(tokens.get(i));
			INDArrayIndex[] indices = new INDArrayIndex[4];
			indices[0] = NDArrayIndex.point(0);
			indices[1] = NDArrayIndex.point(0);
			if (sentencesAlongHeight) {
				indices[2] = NDArrayIndex.point(i);
				indices[3] = NDArrayIndex.all();
			} else {
				indices[2] = NDArrayIndex.all();
				indices[3] = NDArrayIndex.point(i);
			}
			features.put(indices, vector);
		}
		return features;
	}
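
	/*
	 * A minimal prediction sketch using the method above (hypothetical names:
	 * 'net' is an already-trained ComputationGraph, 'iter' a built instance of
	 * this class):
	 *
	 *   INDArray features = iter.loadSingleSentence("some input sentence");
	 *   INDArray probabilities = net.outputSingle(features);
	 *   int predictedClass = Nd4j.argMax(probabilities, 1).getInt(0);
	 */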

	private INDArray getVector(String word) {
		INDArray vector;
		// Reference equality is intentional here: tokenizeSentence assigns the sentinel
		// constant itself, so == distinguishes it from a real word with the same text
		if (unknownWordHandling == UnknownWordHandling.UseUnknownVector && word == UNKNOWN_WORD_SENTINEL) {
			vector = unknown;
		} else {
			if (useNormalizedWordVectors) {
				vector = wordVectors.getWordVectorMatrixNormalized(word);
			} else {
				vector = wordVectors.getWordVectorMatrix(word);
			}
		}
		return vector;
	}

	private List<String> tokenizeSentence(String sentence) {
		Tokenizer t = tokenizerFactory.create(sentence);
		List<String> tokens = new ArrayList<>();
		while (t.hasMoreTokens()) {
			String token = t.nextToken();
			if (!wordVectors.hasWord(token)) {
				switch (unknownWordHandling) {
				case RemoveWord:
					continue; // Skip this token entirely
				case UseUnknownVector:
					token = UNKNOWN_WORD_SENTINEL; // Resolved to the UNK vector by getVector()
				}
			}
			tokens.add(token);
		}
		return tokens;
	}
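
	/*
	 * For example, if the vocabulary lacks "foo": with RemoveWord,
	 * tokenizeSentence("the foo bar") yields ["the", "bar"]; with UseUnknownVector
	 * it yields ["the", UNKNOWN_WORD_SENTINEL, "bar"], and the sentinel is later
	 * resolved to the UNK vector by getVector().
	 */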

	public Map<String, Integer> getLabelClassMap() {
		return new HashMap<>(labelClassMap);
	}

	@Override
	public List<String> getLabels() {
		// Don't just return the list from the LabeledSentenceProvider: the labels were
		// sorted earlier to build the String -> Integer mapping, so rebuild the list
		// in index order
		String[] str = new String[labelClassMap.size()];
		for (Map.Entry<String, Integer> e : labelClassMap.entrySet()) {
			str[e.getValue()] = e.getKey();
		}
		return Arrays.asList(str);
	}

	@Override
	public boolean hasNext() {
		if (sentenceProvider == null) {
			throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
		}
		return sentenceProvider.hasNext();
	}

	@Override
	public DataSet next() {
		return next(minibatchSize);
	}

	/**
	 * Returns the next minibatch of up to {@code num} examples. Features have shape
	 * [currMinibatchSize, 1, maxLength, wordVectorSize] (last two dimensions swapped
	 * when sentencesAlongHeight is false); a features mask of shape
	 * [currMinibatchSize, maxLength] is included only when sentence lengths differ
	 * within the minibatch.
	 */
	@Override
	public DataSet next(int num) {
		if (sentenceProvider == null) {
			throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
		}
		List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
		int maxLength = -1;
		int minLength = Integer.MAX_VALUE; // Track so we know if we can skip mask creation for the "all same length" case
		for (int i = 0; i < num && sentenceProvider.hasNext(); i++) {
			Pair<String, String> p = sentenceProvider.nextSentence();
			List<String> tokens = tokenizeSentence(p.getFirst());
			maxLength = Math.max(maxLength, tokens.size());
			minLength = Math.min(minLength, tokens.size());
			tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
		}
		if (maxLength > maxSentenceLength) {
			maxLength = maxSentenceLength; // Truncate to the configured maximum sentence length
		}
		int currMinibatchSize = tokenizedSentences.size();
		INDArray labels = Nd4j.create(currMinibatchSize, numClasses);
		for (int i = 0; i < tokenizedSentences.size(); i++) {
			String labelStr = tokenizedSentences.get(i).getSecond();
			if (!labelClassMap.containsKey(labelStr)) {
				throw new IllegalStateException(
						"Got label \"" + labelStr + "\" that is not present in list of LabeledSentenceProvider labels");
			}
			int labelIdx = labelClassMap.get(labelStr);
			labels.putScalar(i, labelIdx, 1.0);
		}
		int[] featuresShape = new int[4];
		featuresShape[0] = currMinibatchSize;
		featuresShape[1] = 1;
		if (sentencesAlongHeight) {
			featuresShape[2] = maxLength;
			featuresShape[3] = wordVectorSize;
		} else {
			featuresShape[2] = wordVectorSize;
			featuresShape[3] = maxLength;
		}
		INDArray features = Nd4j.create(featuresShape);
		for (int i = 0; i < currMinibatchSize; i++) {
			List<String> currSentence = tokenizedSentences.get(i).getFirst();
			for (int j = 0; j < currSentence.size() && j < maxSentenceLength; j++) {
				INDArray vector = getVector(currSentence.get(j));
				INDArrayIndex[] indices = new INDArrayIndex[4];
				// TODO: reuse the index array across iterations instead of reallocating it
				indices[0] = NDArrayIndex.point(i);
				indices[1] = NDArrayIndex.point(0);
				if (sentencesAlongHeight) {
					indices[2] = NDArrayIndex.point(j);
					indices[3] = NDArrayIndex.all();
				} else {
					indices[2] = NDArrayIndex.all();
					indices[3] = NDArrayIndex.point(j);
				}
				features.put(indices, vector);
			}
		}
		// Build a features mask (1.0 for real words, 0.0 for padding) only when
		// sentence lengths differ within the minibatch
		INDArray featuresMask = null;
		if (minLength != maxLength) {
			featuresMask = Nd4j.create(currMinibatchSize, maxLength);
			for (int i = 0; i < currMinibatchSize; i++) {
				int sentenceLength = tokenizedSentences.get(i).getFirst().size();
				if (sentenceLength >= maxLength) {
					featuresMask.getRow(i).assign(1.0);
				} else {
					featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
				}
			}
		}
		DataSet ds = new DataSet(features, labels, featuresMask, null);
		if (dataSetPreProcessor != null) {
			dataSetPreProcessor.preProcess(ds);
		}
		cursor += ds.numExamples();
		return ds;
	}

	@Override
	public int totalExamples() {
		return sentenceProvider.totalNumSentences();
	}

	@Override
	public int inputColumns() {
		return wordVectorSize;
	}

	@Override
	public int totalOutcomes() {
		return numClasses;
	}

	@Override
	public boolean resetSupported() {
		return true;
	}

	@Override
	public boolean asyncSupported() {
		return true;
	}

	@Override
	public void reset() {
		cursor = 0;
		sentenceProvider.reset();
	}

	@Override
	public int batch() {
		return minibatchSize;
	}

	@Override
	public int cursor() {
		return cursor;
	}

	@Override
	public int numExamples() {
		return totalExamples();
	}

	@Override
	public void setPreProcessor(DataSetPreProcessor preProcessor) {
		this.dataSetPreProcessor = preProcessor;
	}

	@Override
	public DataSetPreProcessor getPreProcessor() {
		return dataSetPreProcessor;
	}

	@Override
	public void remove() {
		throw new UnsupportedOperationException("Not supported");
	}

	public static class Builder {
		private LabeledSentenceProvider sentenceProvider = null;
		private WordVectors wordVectors;
		private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
		private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
		private boolean useNormalizedWordVectors = true;
		private int maxSentenceLength = -1;
		private int minibatchSize = 32;
		private boolean sentencesAlongHeight = true;
		private DataSetPreProcessor dataSetPreProcessor;

		/**
		 * Specify how the (labelled) sentences / documents should be provided
		 */
		public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
			this.sentenceProvider = labeledSentenceProvider;
			return this;
		}

		/**
		 * Specify how the (labelled) sentences / documents should be provided
		 */
		public Builder sentenceProvider(LabelAwareIterator iterator, @NonNull List<String> labels) {
			LabelAwareConverter converter = new LabelAwareConverter(iterator, labels);
			return sentenceProvider(converter);
		}

		/**
		 * Specify how the (labelled) sentences / documents should be provided
		 */
		public Builder sentenceProvider(LabelAwareDocumentIterator iterator, @NonNull List<String> labels) {
			DocumentIteratorConverter converter = new DocumentIteratorConverter(iterator);
			return sentenceProvider(converter, labels);
		}

		/**
		 * Specify how the (labelled) sentences / documents should be provided
		 */
		public Builder sentenceProvider(LabelAwareSentenceIterator iterator, @NonNull List<String> labels) {
			SentenceIteratorConverter converter = new SentenceIteratorConverter(iterator);
			return sentenceProvider(converter, labels);
		}

		/**
		 * Provide the WordVectors instance that should be used for training
		 */
		public Builder wordVectors(WordVectors wordVectors) {
			this.wordVectors = wordVectors;
			return this;
		}

		/**
		 * The {@link TokenizerFactory} that should be used. Defaults to
		 * {@link DefaultTokenizerFactory}
		 */
		public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
			this.tokenizerFactory = tokenizerFactory;
			return this;
		}

		/**
		 * Specify how unknown words (those that don't have a word vector in the
		 * provided WordVectors instance) should be handled. Default: remove/ignore
		 * unknown words.
		 */
		public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
			this.unknownWordHandling = unknownWordHandling;
			return this;
		}

		/**
		 * Minibatch size to use for the DataSetIterator
		 */
		public Builder minibatchSize(int minibatchSize) {
			this.minibatchSize = minibatchSize;
			return this;
		}

		/**
		 * Whether normalized word vectors should be used. Default: true
		 */
		public Builder useNormalizedWordVectors(boolean useNormalizedWordVectors) {
			this.useNormalizedWordVectors = useNormalizedWordVectors;
			return this;
		}

		/**
		 * Maximum sentence/document length. If sentences exceed this, they will be
		 * truncated to this length by taking the first 'maxSentenceLength' known words.
		 */
		public Builder maxSentenceLength(int maxSentenceLength) {
			this.maxSentenceLength = maxSentenceLength;
			return this;
		}

		/**
		 * If true (default): output features data with shape [minibatchSize, 1,
		 * maxSentenceLength, wordVectorSize]<br>
		 * If false: output features with shape [minibatchSize, 1, wordVectorSize,
		 * maxSentenceLength]<br>
		 * With the default orientation, a convolution kernel of height n and width
		 * wordVectorSize spans n consecutive words, as in typical sentence-CNN
		 * architectures.
		 */
		public Builder sentencesAlongHeight(boolean sentencesAlongHeight) {
			this.sentencesAlongHeight = sentencesAlongHeight;
			return this;
		}

		/**
		 * Optional DataSetPreProcessor
		 */
		public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
			this.dataSetPreProcessor = dataSetPreProcessor;
			return this;
		}

		public CnnSentenceDataSetIterator build() {
			if (wordVectors == null) {
				throw new IllegalStateException(
						"Cannot build CnnSentenceDataSetIterator without a WordVectors instance");
			}
			if (sentenceProvider == null) {
				throw new IllegalStateException(
						"Cannot build CnnSentenceDataSetIterator without a sentence provider"
								+ " (the constructor requires one to build the label map)");
			}
			return new CnnSentenceDataSetIterator(this);
		}
	}
}
